#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Dec 31 2021
Plot total deposition and difference from GEOS-Chem simulations
@author: arifeinberg
"""

import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import cartopy.crs as ccrs
import scipy.io as sio
import datetime
import xarray as xr
from helper_functions import open_Hg,  ds_sel_yr, annual_avg

#%% load data files
# Two GEOS-Chem runs are compared: "old" is the reference run and "new" is the
# run whose deposition is compared against it.
run_old = '0023'
run_new = '0022'
Year = 2015  # year to analyze from simulations


def _gc_monthly_file(run_id, collection):
    """Return the path of a GEOS-Chem ``alltime_m`` diagnostic file.

    run_id     : run number string, e.g. '0023'
    collection : diagnostic collection name, e.g. 'WetLossTotal'
    """
    return f'../GEOS-Chem_runs/run{run_id}/OutputDir/GEOSChem.{collection}.alltime_m.nc4'


fn_old_wdep, fn_new_wdep = (_gc_monthly_file(run, 'WetLossTotal') for run in (run_old, run_new))
fn_old_ddep, fn_new_ddep = (_gc_monthly_file(run, 'DryDep') for run in (run_old, run_new))
fn_old_ocean, fn_new_ocean = (_gc_monthly_file(run, 'MercuryOcean') for run in (run_old, run_new))
fn_old_chem, fn_new_chem = (_gc_monthly_file(run, 'MercuryChem') for run in (run_old, run_new))

# open_Hg is given (old, new) paths; ds1_* is used as the old run and ds2_* as
# the new run throughout the rest of the script.
ds1_wdep, ds2_wdep = open_Hg(fn_old_wdep, fn_new_wdep)      # wet deposition
ds1_ddep, ds2_ddep = open_Hg(fn_old_ddep, fn_new_ddep)      # dry deposition
ds1_ocean, ds2_ocean = open_Hg(fn_old_ocean, fn_new_ocean)  # ocean diagnostics
ds1_chem, ds2_chem = open_Hg(fn_old_chem, fn_new_chem)      # chemistry diagnostics
#%% Load fluxes relevant to deposition
# Every flux is extracted for `Year` from the old (ds1_*) and new (ds2_*)
# datasets and then reduced with annual_avg.

# Wet deposition fluxes [kg/s]
OLD_Hg_totwdep_yr, NEW_Hg_totwdep_yr = (ds_sel_yr(ds, 'WetLossTot_Hg', Year) for ds in (ds1_wdep, ds2_wdep))
OLD_Hg_totwdep, NEW_Hg_totwdep = annual_avg(OLD_Hg_totwdep_yr), annual_avg(NEW_Hg_totwdep_yr)

# Dry deposition fluxes [molec/cm2/s], one diagnostic per Hg species
OLD_Hg0_ddep_yr, NEW_Hg0_ddep_yr = (ds_sel_yr(ds, 'DryDep_Hg0', Year) for ds in (ds1_ddep, ds2_ddep))
OLD_Hg2_ddep_yr, NEW_Hg2_ddep_yr = (ds_sel_yr(ds, 'DryDep_Hg2', Year) for ds in (ds1_ddep, ds2_ddep))
OLD_HgP_ddep_yr, NEW_HgP_ddep_yr = (ds_sel_yr(ds, 'DryDep_HgP', Year) for ds in (ds1_ddep, ds2_ddep))

# Species total (summed Hg0 + Hg2 + HgP, left to right), then annual mean
OLD_total_ddep_yr = (OLD_Hg0_ddep_yr + OLD_Hg2_ddep_yr) + OLD_HgP_ddep_yr
NEW_total_ddep_yr = (NEW_Hg0_ddep_yr + NEW_Hg2_ddep_yr) + NEW_HgP_ddep_yr
OLD_total_ddep, NEW_total_ddep = annual_avg(OLD_total_ddep_yr), annual_avg(NEW_total_ddep_yr)

# Gross ocean uptake of Hg0 [kg/s]
OLD_gross_uptake_yr, NEW_gross_uptake_yr = (ds_sel_yr(ds, 'FluxHg0fromAirToOcean', Year) for ds in (ds1_ocean, ds2_ocean))
OLD_gross_uptake, NEW_gross_uptake = annual_avg(OLD_gross_uptake_yr), annual_avg(NEW_gross_uptake_yr)

# Hg2 uptake by sea salt [kg/s]
OLD_Hg2_salt_yr, NEW_Hg2_salt_yr = (ds_sel_yr(ds, 'LossHg2bySeaSalt_v', Year) for ds in (ds1_chem, ds2_chem))
OLD_Hg2_salt, NEW_Hg2_salt = annual_avg(OLD_Hg2_salt_yr), annual_avg(NEW_Hg2_salt_yr)

# Load grid cell area for unit conversion of model
fn_gbox = '../gbox_areas/GEOSChem_2x25_gboxarea.nc'
ds_gbox = xr.open_dataset(fn_gbox)
# NOTE(review): cell_area must be in m^2 for unit_conv below to give kg/s
# -- confirm against the file's units attribute.
gbox_GC = ds_gbox.cell_area

# convert from molec/cm2/s to kg/s
s_in_yr = 365.2425 * 24 * 3600 # s in one (mean Gregorian) year
g_kg = 1e3 # g per kg
cm2_m2 = 1e4 # cm^2 per m^2
MW_Hg = 200.59 # g mol^-1
# Avogadro constant, exact SI value (molec mol^-1); the previous 6.02e23
# under-stated it by ~0.04%, biasing every dry-deposition flux.
avo = 6.02214076e23

# molec/cm2/s -> mol/cm2/s (/avo) -> g/cm2/s (*MW_Hg) -> g/m2/s (*cm2_m2)
# -> kg/m2/s (/g_kg) -> kg/s per grid cell (*cell area)
unit_conv = 1. / avo * MW_Hg * cm2_m2 / g_kg * gbox_GC

# Total dry deposition, now in kg/s per grid cell
OLD_total_ddep = OLD_total_ddep * unit_conv
NEW_total_ddep = NEW_total_ddep * unit_conv

# Annual-mean Hg0 dry deposition in kg/yr per grid cell
OLD_Hg0_u = annual_avg(OLD_Hg0_ddep_yr) * unit_conv * s_in_yr
NEW_Hg0_u = annual_avg(NEW_Hg0_ddep_yr) * unit_conv * s_in_yr

# calculate total deposition in kg/s:
# wet + dry + gross ocean Hg0 uptake + Hg2 uptake by sea salt
OLD_total_dep = OLD_Hg_totwdep + OLD_total_ddep + OLD_gross_uptake + OLD_Hg2_salt
NEW_total_dep = NEW_Hg_totwdep + NEW_total_ddep + NEW_gross_uptake + NEW_Hg2_salt

# convert from kg/s to kg/yr
OLD_total_dep = OLD_total_dep * s_in_yr
NEW_total_dep = NEW_total_dep * s_in_yr

# Absolute differences (new - reference), kg/yr per grid cell.
Abs_diff = NEW_total_dep - OLD_total_dep
Abs_diff_Hg0 = NEW_Hg0_u - OLD_Hg0_u  # Hg0 dry deposition only
# Largest absolute change. nanmax skips NaN cells; np.max would make the
# whole result NaN if any cell were missing.
Abs_MaxVal = np.nanmax(np.abs(Abs_diff.values))

# Find the percent difference of the models (relative to the reference run).
Perc_diff = (Abs_diff / OLD_total_dep) * 100

# Largest absolute percent change over finite cells only. Cells with zero
# reference deposition give +/-inf (x/0) or NaN (0/0); an isnan test alone
# would let the inf values through.
Perc_diff_non_nan = Perc_diff.values[np.isfinite(Perc_diff.values)]
Perc_MaxVal = np.max(np.abs(Perc_diff_non_nan)) if Perc_diff_non_nan.size else 0.

# Cap the limit at 100% (deposition can't be negative, so a decrease is at
# most -100%).
Perc_MaxVal = min(Perc_MaxVal, 100)
# define your scale, with white at zero. The figure below uses these fixed
# limits; Abs_MaxVal / Perc_MaxVal are not used by the plot in this script.
vmin = -60
vmax = 20
norm = colors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

# Create a single map panel (Eckert IV projection) for the figure.
f1, axes = plt.subplots(1, 1, figsize=[12, 9],
                        subplot_kw=dict(projection=ccrs.EckertIV()),
                        gridspec_kw=dict(hspace=0.2, wspace=0.1))

#%% Plot total Hg deposition difference plot
# The colour limits (vmin/vcenter/vmax) are carried by `norm` alone. Also
# passing vmin/vmax is redundant, and matplotlib rejects a Normalize instance
# combined with explicit vmin/vmax.
im = Perc_diff.plot.pcolormesh(x='lon', y='lat', ax=axes,
                               transform=ccrs.PlateCarree(),
                               rasterized=True, cmap='RdBu_r',
                               add_colorbar=False, norm=norm)
# Only the low end gets an extension arrow; values above vmax are drawn in
# the top colour without one.
cbar = plt.colorbar(im, ax=axes, fraction=0.046, pad=0.04, extend='min',
                    orientation='horizontal')
cbar.set_label('Percent difference (%)', fontsize=20)
cbar.ax.tick_params(labelsize=19)

# Add a title.
axes.set_title("Total deposition (relative difference)", fontsize=30)
# Add the coastlines.
axes.coastlines()

f1.savefig('Figures/Fig5b_SA_on_sav_noant.pdf', bbox_inches='tight')

#plot.savefig('../Figures/total_dep_change_savannize.pdf',bbox_inches = 'tight')
#%% Calculate difference in Hg deposition over ocean
# Load land mask
fn_land = '../landmask_geos2x25.nc'
ds_l = xr.open_dataset(fn_land)
# NOTE(review): the variable `sea` from the *land*-mask file is used as the
# sea mask below -- confirm the file/variable naming is intended.
mask_l = ds_l.sea

fn_sea = '../seamask_geos2x25.nc'
ds_s = xr.open_dataset(fn_sea)
# NOTE(review): likewise, `topo` from the *sea*-mask file is used as the land
# mask below -- confirm.
mask_s = ds_s.topo

# Masked sums of the kg/yr difference field, printed in Mg/yr (/1000)
print("Diff over sea")
print((Abs_diff * mask_l).sum()/1000.)
print("Diff over land")
print((Abs_diff * mask_s).sum()/1000.)

# calculate as percent of total deposition to ocean. The ratio is a fraction,
# so scale by 100 to match the "percent" label.
print("Diff as percent of total sea deposition")
print((Abs_diff * mask_l).sum() / (OLD_total_dep * mask_l).sum() * 100)

# select Amazon coordinate bounds (latitude / longitude in degrees)
lat_min, lat_max = -34, 14
lon_min, lon_max = -82, -33

# Olson land-type masks (presumably one data variable per Olson land type,
# stacked into a leading dimension by to_array) and the Olson -> drydep
# mapping table. The land-type file also supplies the lat/lon grid used for
# the regional selection.
fn_ols2 = 'misc_Data/Olson_2001_Land_Type_Masks.2_25.generic.nc'
fn_ols1 = 'misc_Data/Olson_2001_Drydep_Inputs.nc'
ds_ols2 = xr.open_dataset(fn_ols2)
ds_ols1 = xr.open_dataset(fn_ols1)
Olson_landtype = ds_ols2.to_array().squeeze()
lat = ds_ols2.lat  # latitude
lon = ds_ols2.lon  # longitude
#%% restrict to amazon area the Olson landtype
# Boolean lat/lon masks on the Olson grid, computed once. They are applied
# positionally to the GEOS-Chem fields as well, so both grids must match.
lat_in_A = (lat <= lat_max) & (lat >= lat_min)
lon_in_A = (lon <= lon_max) & (lon >= lon_min)
Olson_landtype_A = Olson_landtype[:, lat_in_A, lon_in_A]
Abs_diff_A = Abs_diff_Hg0[lat_in_A, lon_in_A]  # Hg0 dry-dep change, kg/yr
OLD_Hg0_A = OLD_Hg0_u[lat_in_A, lon_in_A]
NEW_Hg0_A = NEW_Hg0_u[lat_in_A, lon_in_A]

IDEP = ds_ols1.IDEP.values # Mapping index: Olson land type ID to drydep ID
# Positions of the Olson land types that map to drydep category 6 (treated
# here as rainforest). NOTE(review): assumes IDEP is 1-D and ordered like the
# leading dimension of Olson_landtype.
IOLSON_6 = np.asarray(np.where(IDEP==6)).flatten()

# Cells where those land types together cover more than 0.1 of the cell.
# Reduce along the land-type axis rather than with Python's builtin sum: the
# builtin loops in Python and returns scalar 0 when IOLSON_6 is empty, which
# would break the [0, :] / [1, :] indexing below. skipna=False keeps the
# builtin's NaN propagation.
rainforest_frac_A = Olson_landtype_A[IOLSON_6, :, :].sum(axis=0, skipna=False)
ind_rainforest = np.asarray(np.where(rainforest_frac_A > 0.1))  # (2, n_cells): row, col

Abs_diff_A_v = Abs_diff_A.values
OLD_Hg0_A_v = OLD_Hg0_A.values
NEW_Hg0_A_v = NEW_Hg0_A.values

# Sums over rainforest cells of the Hg0 dry deposition (kg/yr)
total_diff_A = np.sum(Abs_diff_A_v[ind_rainforest[0,:],ind_rainforest[1,:]])
total_OLD_Hg0_A = np.sum(OLD_Hg0_A_v[ind_rainforest[0,:],ind_rainforest[1,:]])
total_NEW_Hg0_A = np.sum(NEW_Hg0_A_v[ind_rainforest[0,:],ind_rainforest[1,:]])
# Printed in Mg/yr (kg/yr / 1000)
print("Old amazon dep")
print(total_OLD_Hg0_A/1000)
print("New amazon dep")
print(total_NEW_Hg0_A/1000)
